Removed no_grad from solver #19
Conversation
b61b6d0 to 68f0f96
    step_size=step_size if method != "dopri5" else None,
    time_grid=time_grid,
    method=method,
    enable_grad=True,
Check that grads are not computed without this?
Added test
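For context, a minimal sketch of what such a gradient check could look like. The `solver.sample` call, its `x_init` keyword, and the default grad-disabled behaviour are assumptions based on the diff above, not the library's verified API:

```python
import torch

def check_sample_gradients(solver, x_init, step_size, time_grid, method):
    # With enable_grad=True the sample should stay attached to the autograd
    # graph, so a backward pass populates x_init.grad.
    x_init = x_init.clone().requires_grad_(True)
    x_1 = solver.sample(
        x_init=x_init,
        step_size=step_size if method != "dopri5" else None,
        time_grid=time_grid,
        method=method,
        enable_grad=True,
    )
    x_1.sum().backward()
    assert x_init.grad is not None

    # Reviewer's question: without the flag no graph should be built
    # (assuming enable_grad defaults to a no-grad path).
    x_init = x_init.detach().clone().requires_grad_(True)
    x_1 = solver.sample(
        x_init=x_init,
        step_size=step_size if method != "dopri5" else None,
        time_grid=time_grid,
        method=method,
    )
    assert not x_1.requires_grad
```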
@@ -105,6 +127,7 @@ def dummy_log_p(x: Tensor) -> Tensor:
    log_p0=dummy_log_p,
    step_size=step_size,
    exact_divergence=True,
    enable_grad=True,
Check that grads are not computed without this?
Added test
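Likewise, a hedged sketch for the likelihood path; the `compute_likelihood` name and its return signature are assumptions, while the keyword arguments mirror the diff above:

```python
import torch

def check_likelihood_gradients(solver, x_1, dummy_log_p, step_size):
    # With enable_grad=True the log-likelihood should be differentiable
    # w.r.t. x_1.
    x_1 = x_1.clone().requires_grad_(True)
    x_0, log_p = solver.compute_likelihood(  # return signature assumed
        x_1=x_1,
        log_p0=dummy_log_p,
        step_size=step_size,
        exact_divergence=True,
        enable_grad=True,
    )
    log_p.sum().backward()
    assert x_1.grad is not None
```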
@@ -174,16 +175,15 @@ def dynamics_func(t, states):
    y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
    ode_opts = {"step_size": step_size} if step_size is not None else {}

    with torch.no_grad():
Was this `no_grad` unnecessary previously?
Yes, this was unnecessary.
Do the docs for the affected methods get updated with the new `enable_grad` argument?
Moving this PR to internal
It was suggested that no part of the core library, such as the sampler and likelihood computation, should be wrapped in `no_grad`. This is because if the code is expected to be integrated into other people's projects, it should not enforce `no_grad`; it should instead let the user decide whether to track the computation graph. While users can add their own `no_grad`, it is impossible for them to remove a `no_grad` that has already been applied.

This PR removes `no_grad` from the library.

I tested it with the example notebooks, and I also added a unit test to make sure we can differentiate through the ODE solver and the likelihood computation.
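To illustrate the argument with plain PyTorch (a generic sketch, not code from this library): a caller can always add `no_grad` around a grad-enabled call, but wrapping a call in `enable_grad` from the outside cannot undo a `no_grad` that the callee applies internally.

```python
import torch

def hardcoded_no_grad(x):
    # Stand-in for a library function that enforces no_grad internally.
    with torch.no_grad():
        return x * 2

def grad_friendly(x):
    # Stand-in for the same function after this PR: no no_grad inside.
    return x * 2

x = torch.ones(3, requires_grad=True)

# The caller cannot re-enable the graph from outside the call:
with torch.enable_grad():
    y = hardcoded_no_grad(x)
print(y.requires_grad)  # False -- the inner no_grad wins

# But the caller can always opt out of gradients themselves:
with torch.no_grad():
    z = grad_friendly(x)
print(z.requires_grad)  # False

# And gradients flow by default when the library does not interfere:
w = grad_friendly(x)
print(w.requires_grad)  # True
```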